#include "utils.h"
#include <algorithm>
#include <assert.h>
#include <cerrno>
#include <cstdint>
#include <string.h>
#include <unistd.h>
#include <vector>
#include <zlib.h>

// Stores the single byte `x` at buf_[pos] and returns the index of the byte
// that follows it (pos + 1). Asserts that buf_[pos] lies inside the buffer.
size_t WriteBuf::write_uint8_at(size_t pos, uint8_t x) {
  assert(bufsize_ >= pos);
  assert(bufsize_ - pos >= sizeof(uint8_t));
  buf_[pos] = x;
  return pos + sizeof(uint8_t);
}

// Stores `x` at buf_[pos] in big-endian byte order (most significant byte
// first) and returns the index just past the two bytes written. Asserts that
// both bytes fit in the buffer.
size_t WriteBuf::write_uint16_at(size_t pos, uint16_t x) {
  assert(bufsize_ >= pos);
  assert(bufsize_ - pos >= sizeof(uint16_t));
  const uint8_t high = static_cast<uint8_t>(x >> 8);
  const uint8_t low = static_cast<uint8_t>(x & 0xFF);
  buf_[pos] = high;
  buf_[pos + 1] = low;
  return pos + 2;
}

// Stores `x` at buf_[pos] in big-endian byte order (most significant byte
// first) and returns the index just past the four bytes written. Asserts that
// all four bytes fit in the buffer.
size_t WriteBuf::write_uint32_at(size_t pos, uint32_t x) {
  assert(bufsize_ >= pos);
  assert(bufsize_ - pos >= sizeof(uint32_t));
  // Walk the shift down from the top byte (24) to the bottom byte (0).
  for (int shift = 24; shift >= 0; shift -= 8) {
    buf_[pos] = static_cast<uint8_t>((x >> shift) & 0xFF);
    ++pos;
  }
  return pos;
}

// Fills buf_[pos, pos + n) with n copies of the byte `x` and returns the index
// just past the filled range (pos + n). Asserts that the range fits in the
// buffer.
size_t WriteBuf::write_many_at(size_t pos, uint8_t x, size_t n) {
  assert(bufsize_ >= pos);
  assert(bufsize_ - pos >= n);
  const size_t end = pos + n;
  memset(&buf_[pos], x, n);
  return end;
}

// Copies the n bytes starting at `bytes` into buf_[pos, pos + n) and returns
// the index just past them (pos + n). Asserts that the range fits in the
// buffer. `bytes` may be null when n == 0 (e.g. the data() of an empty
// vector).
size_t WriteBuf::write_bytes_at(size_t pos, const uint8_t *bytes, size_t n) {
  assert(bufsize_ >= pos);
  assert(bufsize_ - pos >= n);
  // memcpy requires valid pointers even for a zero-length copy, so skip the
  // call entirely when there is nothing to write.
  if (n != 0)
    memcpy(&buf_[pos], bytes, n);
  return pos + n;
}

// Appends the byte `x` at the current write offset (buf_[offset_]) and
// advances offset_ past it.
void WriteBuf::write_uint8(uint8_t x) {
  const size_t next = write_uint8_at(offset_, x);
  offset_ = next;
}

// Appends `x` in big-endian byte order at the current write offset and
// advances offset_ past the two bytes.
void WriteBuf::write_uint16(uint16_t x) {
  const size_t next = write_uint16_at(offset_, x);
  offset_ = next;
}

// Appends `x` in big-endian byte order at the current write offset and
// advances offset_ past the four bytes.
void WriteBuf::write_uint32(uint32_t x) {
  const size_t next = write_uint32_at(offset_, x);
  offset_ = next;
}

// Appends n copies of the byte `x` at the current write offset and advances
// offset_ by n.
void WriteBuf::write_many(uint8_t x, size_t n) {
  const size_t next = write_many_at(offset_, x, n);
  offset_ = next;
}

// Appends the n bytes starting at `bytes` at the current write offset and
// advances offset_ by n.
void WriteBuf::write_bytes(const uint8_t *bytes, size_t n) {
  const size_t next = write_bytes_at(offset_, bytes, n);
  offset_ = next;
}

// Appends the characters of the NUL-terminated string `str` at the current
// write offset. The terminating NUL itself is not written.
void WriteBuf::write_string(const char *str) {
  const size_t len = strlen(str);
  const uint8_t *raw = reinterpret_cast<const uint8_t *>(str);
  write_bytes(raw, len);
}

// Returns the current write position: the index at which the next sequential
// write_*() call will store data.
size_t WriteBuf::offset() const {
  return offset_;
}

// Reserves n bytes at the current write position without touching them, and
// returns the index of the first reserved byte. This is typically used for
// size or offset fields whose values are only known after later data has
// been written; the caller can fill the gap afterwards, e.g. with one of the
// write_*_at() methods. Asserts that the gap fits in the buffer.
size_t WriteBuf::insert_gap(size_t n) {
  const size_t gap_start = offset_;
  assert(bufsize_ >= gap_start);
  assert(bufsize_ - gap_start >= n);
  offset_ += n;
  return gap_start;
}

// Writes the bytes produced so far, buf_[0, offset_), to file descriptor `fd`.
// write(2) may transfer fewer bytes than requested (pipes, sockets, signal
// interruption), so keep writing until everything is out, retrying on EINTR.
// Any other error, or a write that makes no progress, stops the loop.
// NOTE(review): such failures are dropped because this function returns void;
// consider changing the signature so callers can detect them.
void WriteBuf::write_to_fd(int fd) {
  size_t sent = 0;
  while (sent < offset_) {
    const ssize_t n = write(fd, buf_ + sent, offset_ - sent);
    if (n < 0) {
      if (errno == EINTR)
        continue; // interrupted before anything was written; try again
      return;
    }
    if (n == 0)
      return; // no progress; avoid spinning forever
    sent += static_cast<size_t>(n);
  }
}

// Compresses `input` into `output` as a single zlib stream, using
// deflateInit() at Z_BEST_COMPRESSION. On success, `output` is overwritten and
// resized to exactly the compressed length, and Z_OK is returned. `input` is
// only read.
//
// Returns the zlib error code if deflateInit() or deflate() fails. In that
// case the contents of `output` are unspecified.
int compress(std::vector<uint8_t> &output, std::vector<uint8_t> &input) {
  // Value-initialize so zalloc/zfree/opaque are null (zlib's default
  // allocator) and no field, e.g. avail_out, is ever read uninitialized.
  z_stream strm{};

  int ret = deflateInit(&strm, Z_BEST_COMPRESSION);
  if (ret != Z_OK)
    return ret;

  // Input is handed to deflate() in chunks. Each per-call output window is
  // also capped, so both counts always fit zlib's 32-bit avail_in/avail_out.
  constexpr size_t kInChunk = 0x10000;
  constexpr size_t kMaxOutStep = size_t{1} << 30;

  size_t input_pos = 0;
  size_t output_pos = 0;
  output.resize(0x10000);

  // Run until deflate() reports the end of the stream. Checking the return
  // code (rather than avail_out) also handles empty input correctly: the
  // first call already uses Z_FINISH and emits a valid, empty zlib stream.
  do {
    assert(input_pos <= input.size());
    assert(output_pos <= output.size());
    if (output_pos == output.size())
      output.resize(output.size() * 2); // guarantees avail_out > 0 below

    const size_t total_avail_in = input.size() - input_pos;
    const size_t avail_in = std::min(kInChunk, total_avail_in);
    // Only a call that is handed all of the remaining input may finish the
    // stream; once Z_FINISH starts it keeps being used on later iterations.
    const int flush = avail_in < total_avail_in ? Z_NO_FLUSH : Z_FINISH;
    const size_t avail_out = std::min(kMaxOutStep, output.size() - output_pos);

    // Pointers are refreshed every iteration because resize() may reallocate.
    strm.next_in = input.data() + input_pos;
    strm.avail_in = static_cast<unsigned int>(avail_in);
    strm.next_out = output.data() + output_pos;
    strm.avail_out = static_cast<unsigned int>(avail_out);

    ret = deflate(&strm, flush);
    if (ret == Z_STREAM_ERROR) {
      (void)deflateEnd(&strm);
      return ret;
    }

    input_pos += avail_in - strm.avail_in;
    output_pos += avail_out - strm.avail_out;
  } while (ret != Z_STREAM_END);

  assert(input_pos == input.size());
  output.resize(output_pos);

  (void)deflateEnd(&strm);
  return Z_OK;
}
